import re
import os
import pickle
import random
import numpy as np
import torch


def add_linker_to_data(sequences, distograms, linker=None, linker_len=None, folding=False, max_len=1500):
    """
    Join the two chains of each dimer into a single sequence string and
    collect the matching chain lengths and distograms.

    If ``linker`` is given (discrete linker), the chains are joined with the
    linker string itself; ``linker_len`` is overridden by ``len(linker)`` and
    ``folding`` is ignored. Otherwise (continuous linker) ``linker_len`` must
    be given; it is only used by the ``max_len`` filter, and the chains are
    joined with ':' when ``folding`` is true or concatenated directly when it
    is false.

    Args:
        sequences: dict mapping name -> dict with the keys 'dimer_seq'
            (chains separated by ','), 'chain1_len' and 'chain2_len'.
        distograms: dict mapping name -> distogram, or None. When given, only
            the names it contains are processed. When None, every name in
            ``sequences`` is processed and the returned 'distos' is empty.
        linker: linker string inserted between the chains, or None.
        linker_len: linker length counted by the length filter when
            ``linker`` is None.
        folding: in continuous mode, separate the chains with ':' instead of
            concatenating them.
        max_len: entries with chain1_len + linker_len + chain2_len above this
            value are skipped; None disables the filter.

    Returns:
        dict with the keys 'seqs', 'distos' and 'lengths', each a dict keyed
        by name. 'lengths' values are ``[chain1_len, chain2_len]``.

    Raises:
        AssertionError: if both ``linker`` and ``linker_len`` are None.
        KeyError: if a name from ``distograms`` is missing in ``sequences``.
    """
    if linker is not None:
        # Discrete linker (original note: "for esm").
        linker_len = len(linker)
        separator = linker
    else:
        # Continuous linker (original note: "for esmfold").
        assert linker_len is not None, 'linker_len should be specified for continuous linker'
        separator = ':' if folding else ''

    # Without distograms, fall back to every available sequence. The original
    # only did this in continuous mode; discrete mode crashed on None.
    names = distograms if distograms is not None else sequences

    multimer_sequences = {}
    multimer_lengths = {}
    multimer_distograms = {}
    for name in names:
        seq_info = sequences[name]
        chain1_len = seq_info['chain1_len']
        chain2_len = seq_info['chain2_len']
        if max_len is not None and chain1_len + linker_len + chain2_len > max_len:
            continue
        multimer_sequences[name] = seq_info['dimer_seq'].replace(',', separator)
        multimer_lengths[name] = [chain1_len, chain2_len]
        if distograms is not None:
            multimer_distograms[name] = distograms[name]

    return {'seqs': multimer_sequences, 'distos': multimer_distograms, 'lengths': multimer_lengths}


def seed_everything(args):
    """
    Seed Python's ``random``, NumPy and PyTorch (CPU and all CUDA devices)
    with ``args.seed``.

    When CUDA is available, cuDNN benchmark mode is additionally switched off
    and cuDNN deterministic mode switched on.

    Args:
        args: object exposing an integer ``seed`` attribute.
    """
    seed = args.seed
    seeders = (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all)
    for apply_seed in seeders:
        apply_seed(seed)

    if not torch.cuda.is_available():
        return

    # Reproduce on CUDA; it will make the program slower.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


def random_sampling(data, n=10, seed=666):
    """
    Draw a random subset of entries from ``data`` (e.g. from train or valid
    for few-shot learning).

    Candidate names are the keys of ``data['distos']``; at most ``n`` of them
    are drawn without replacement. Note that this reseeds NumPy's global
    random state with ``seed`` on every call.

    Args:
        data: dict with the keys 'seqs', 'distos' and 'lengths', each a dict
            keyed by name.
        n: maximum number of names to draw.
        seed: seed passed to ``np.random.seed`` before drawing.

    Returns:
        dict with the keys 'seqs', 'distos' and 'lengths', restricted to the
        drawn names.
    """
    fields = ('seqs', 'distos', 'lengths')
    seq_map, disto_map, length_map = (data[field] for field in fields)

    pool = list(disto_map)
    n_draw = min(n, len(pool))

    np.random.seed(seed)
    picked = np.random.choice(pool, size=n_draw, replace=False)

    subset = {}
    for field, source in zip(fields, (seq_map, disto_map, length_map)):
        subset[field] = {name: source[name] for name in picked}
    return subset


def load_esm2_seq_representation(data_dir, data_mode, linker_len, crop=True, crop_size=200):
    """
    Load pickled sequence representations from
    ``<data_dir>/esm2_3b_feats/<linker_len>/<file_name>_seq_rep.pickle``.

    ``file_name`` is ``data_mode``, with ``_crop<crop_size>`` appended only
    when ``data_mode`` is 'train' and ``crop`` is true.

    Args:
        data_dir: root directory containing the 'esm2_3b_feats' folder.
        data_mode: split name used as the file-name stem (e.g. 'train').
        linker_len: linker length, used as the sub-directory name.
        crop: whether to load the cropped file for the 'train' split.
        crop_size: crop size encoded in the cropped file name.

    Returns:
        The object stored in the pickle file.
    """
    use_cropped = data_mode == 'train' and crop
    crop_tag = '_crop' + str(crop_size) if use_cropped else ''
    file_name = data_mode + crop_tag

    relative_path = '/'.join(['esm2_3b_feats', str(linker_len), file_name + '_seq_rep.pickle'])
    data_path = os.path.join(data_dir, relative_path)

    # NOTE: pickle.load executes arbitrary code; only point this at trusted files.
    with open(data_path, mode='rb') as handle:
        return pickle.load(handle)

